import sys

import os
import math
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Any
import glob

import torch
from torch.optim.lr_scheduler import LambdaLR
from torchmetrics.functional import accuracy
import torch.nn.functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from dataset.util_tools.optim_factory import get_parameter_groups
from dataset.util_tools.utils import MetricLogger

# Preset codes (>= 12) mapping to groups of transformer block indices.
# Families: 12 = all blocks, 2x = consecutive pairs, 3x = triples,
# 4x = quadruples, 6x = halves, 99 = no blocks.
_BLOCK_LAYER_PRESETS = {
    12: list(range(12)),
    21: [0, 1],
    22: [2, 3],
    23: [4, 5],
    24: [6, 7],
    25: [8, 9],
    26: [10, 11],
    31: [0, 1, 2],
    32: [3, 4, 5],
    33: [6, 7, 8],
    34: [9, 10, 11],
    # 41 previously read [1, 2, 3, 4], which overlapped 42 at block 4 and
    # never covered block 0; the 4x family partitions 0..11 like the others.
    41: [0, 1, 2, 3],
    42: [4, 5, 6, 7],
    43: [8, 9, 10, 11],
    61: [0, 1, 2, 3, 4, 5],
    62: [6, 7, 8, 9, 10, 11],
    99: [],
}


def get_block_layer(block_layers):
    """Translate a block-selection code into a list of block indices.

    Args:
        block_layers: an int. Values below 12 select that single block;
            values listed in ``_BLOCK_LAYER_PRESETS`` select a preset group.

    Returns:
        A new list of block indices (safe for the caller to mutate).

    Raises:
        ValueError: if ``block_layers`` is neither below 12 nor a known preset
            (the original fell through to an ``UnboundLocalError``).
    """
    if block_layers < 12:
        return [block_layers]
    try:
        # Copy so callers cannot corrupt the shared preset table.
        return list(_BLOCK_LAYER_PRESETS[block_layers])
    except KeyError:
        raise ValueError(f"Unsupported block_layers code: {block_layers!r}") from None


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warmup followed by cosine decay.

    Args:
        base_value: value reached at the end of warmup (start of the cosine).
        final_value: value the cosine decays towards.
        epochs: number of epochs.
        niter_per_ep: iterations per epoch.
        warmup_epochs: warmup length in epochs.
        start_warmup_value: value at the first warmup iteration.
        warmup_steps: if > 0, overrides the warmup length in iterations.

    Returns:
        1-D numpy array of length ``epochs * niter_per_ep``.
    """
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)

    # Gate on the effective warmup length: previously a positive
    # ``warmup_steps`` with ``warmup_epochs == 0`` produced no ramp and the
    # length assertion below always failed.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    # max(..., 1) only guards the empty case, where the result is empty anyway.
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / max(len(iters), 1)))

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule

def save_lr_image(data, path, name):
    """Plot a per-iteration schedule and save it as ``<path>/<name>.png``.

    Args:
        data: 1-D sequence of values, plotted against their index.
        path: output directory; created if missing.
        name: used as title, y-axis label, legend entry and file stem.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(np.arange(len(data)), data, label=name)
        ax.set_title(name)
        ax.set_xlabel('Index')
        ax.set_ylabel(name)
        ax.legend()
        ax.grid(True)

        if path:
            os.makedirs(path, exist_ok=True)
        # Save the figure to a file inside ``path`` (joining avoids relying
        # on the caller supplying a trailing separator).
        fig.savefig(os.path.join(path, f'{name}.png'))
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

class lightening_module(pl.LightningModule):
    """Lightning wrapper that fine-tunes selected parameters of ``model``.

    Outside test-only mode, every parameter is frozen except those whose name
    contains ``'merge'`` or ``'adapter'`` or starts with ``'fc'``. Learning
    rate and weight decay follow per-iteration cosine schedules that are
    written into the optimizer's param groups in ``on_train_batch_start``.
    ``predict_step`` returns a ``pandas.DataFrame`` with one row per sample.
    """

    def __init__(
            self,
            model,
            num_classes,
            mixup_fn,
            assigner,
            len_train,
            save_directory,
            args,
            test_only: bool = False,
            visualize: bool = None,
    ):
        """Build the module.

        Args:
            model: network called as ``model(x)``.
            num_classes: class count passed to the accuracy metrics.
            mixup_fn: optional callable ``(x, y) -> (x, y)`` applied to training
                batches; when set, the training loss is ``SoftTargetCrossEntropy``.
            assigner: optional object exposing ``get_layer_id`` / ``get_scale``,
                forwarded to ``get_parameter_groups`` (may be ``None``).
            len_train: iteration count per epoch before division by
                ``len(args.device)`` -- NOTE(review): presumably the number of
                training batches; confirm with the caller.
            save_directory: directory the schedule plots are written to.
            args: hyper-parameter namespace (``lr``, ``min_lr``, ``epochs``,
                ``weight_decay``, ``weight_decay_end``, ``warmup_epochs``,
                ``warmup_steps``, ``smoothing``, ``opt``, ``batch_size``,
                ``device``). ``args.weight_decay_end`` is filled in place
                when it is ``None``.
            test_only: skip parameter freezing and schedule construction.
            visualize: accepted for interface compatibility; unused here.
        """
        super().__init__()
        self.save_hyperparameters(args, ignore=['model'])
        self.model = model
        self.num_classes = num_classes
        self.mixup_fn = mixup_fn
        self.assigner = assigner
        self.len_train = len_train
        self.args = args

        # Evaluation loss is used by validation_step and predict_step, so it
        # must exist in test-only mode too (predict_step previously crashed).
        self.loss_func = torch.nn.CrossEntropyLoss()
        self.wd_scheduler = None
        self.lr_scheduler = None
        self.batch_step = 0

        if test_only:
            return

        # Train only the merge / adapter / classifier-head parameters.
        for name, param in self.model.named_parameters():
            trainable = 'merge' in name or name.startswith("fc") or 'adapter' in name
            param.requires_grad = trainable
            if trainable:
                print(name)

        if mixup_fn is not None:
            # smoothing is handled with mixup label transform
            self.train_loss = SoftTargetCrossEntropy()
        elif args.smoothing > 0.:
            self.train_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
        else:
            self.train_loss = torch.nn.CrossEntropyLoss()

        if self.args.weight_decay_end is None:
            self.args.weight_decay_end = self.args.weight_decay

        # Both schedules are indexed by the same per-batch counter, so they
        # must be built with the same number of iterations per epoch.
        niter_per_ep = self.len_train // len(args.device)
        self.wd_scheduler = cosine_scheduler(
            self.args.weight_decay, self.args.weight_decay_end, self.args.epochs, niter_per_ep)
        self.lr_scheduler = cosine_scheduler(
            self.args.lr, self.args.min_lr, self.args.epochs, niter_per_ep,
            warmup_epochs=self.args.warmup_epochs, warmup_steps=self.args.warmup_steps,
        )
        save_lr_image(self.lr_scheduler, save_directory, 'Learning rate')
        save_lr_image(self.wd_scheduler, save_directory, 'Weight decay')

    def forward(self, x):
        return self.model(x)

    def _log_accuracy(self, prefix, logits, target):
        """Log top-1 / top-5 accuracy as ``{prefix}_acc`` and ``{prefix}_acc5``."""
        probs = torch.softmax(logits, dim=-1)
        for suffix, top_k in (("acc", 1), ("acc5", 5)):
            self.log(
                f"{prefix}_{suffix}",
                accuracy(probs, target, task="multiclass", num_classes=self.num_classes, top_k=top_k),
                prog_bar=True, sync_dist=True, batch_size=self.args.batch_size,
            )

    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        if self.mixup_fn is not None:
            x, y = self.mixup_fn(x, y)
        y_hat = self(x)
        loss = self.train_loss(y_hat, y)

        self.log("train_loss", loss, prog_bar=True, sync_dist=True, batch_size=self.args.batch_size)
        # With mixup the targets are soft, so hard-label accuracy is skipped.
        if self.mixup_fn is None:
            self._log_accuracy("train", y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        y_hat = self(x)
        loss = self.loss_func(y_hat, y)

        self.log("val_loss", loss, prog_bar=True, sync_dist=True, batch_size=self.args.batch_size)
        self._log_accuracy("val", y_hat, y)
        return loss

    def on_train_epoch_end(self) -> None:
        group = self.optimizers().optimizer.param_groups[0]
        self.log("lr", group["lr"], on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.args.batch_size)
        self.log("weight_decay", group["weight_decay"], on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.args.batch_size)

    def on_train_batch_start(self, batch, batch_idx) -> None:
        """Write the scheduled lr / weight decay for this batch into every param group."""
        if self.lr_scheduler is None or len(self.lr_scheduler) == 0:
            return  # test-only module (or empty schedule): nothing to apply

        # Clamp so running a few more batches than the schedule covers (e.g.
        # per-device batch counts rounded up) reuses the final value instead
        # of raising IndexError.
        lr_step = min(self.batch_step, len(self.lr_scheduler) - 1)
        for group in self.optimizers().optimizer.param_groups:
            if self.wd_scheduler is not None and len(self.wd_scheduler) and group["weight_decay"] > 0:
                wd_step = min(self.batch_step, len(self.wd_scheduler) - 1)
                group["weight_decay"] = self.wd_scheduler[wd_step]
            group["lr"] = self.lr_scheduler[lr_step] * group.get("lr_scale", 1.0)
        self.batch_step += 1

    def configure_optimizers(self):
        # Treat a missing/zero weight decay as 0.0 instead of leaving
        # ``parameters`` unbound (the original raised UnboundLocalError).
        weight_decay = self.args.weight_decay or 0.0
        parameters = get_parameter_groups(
            self.model, weight_decay, [],
            get_num_layer=self.assigner.get_layer_id if self.assigner is not None else None,
            get_layer_scale=self.assigner.get_scale if self.assigner is not None else None,
        )

        if self.args.opt == 'adamw':
            return torch.optim.AdamW(parameters, lr=self.args.lr, weight_decay=weight_decay)
        raise ValueError(f"Unsupported optimizer {self.args.opt!r}; only 'adamw' is implemented")

    def predict_step(self, batch, batch_idx):
        """Run the model on one batch and return per-sample results as a DataFrame.

        ``batch`` is indexed as (samples, target, ids, chunk_nb, split_nb).
        Columns: ID, Target, Chunk, Split, Loss (the batch-mean loss repeated
        on every row) and Output (the logit vector as a list).
        """
        samples, target, ids = batch[0], batch[1], batch[2]
        chunk_nb, split_nb = batch[3], batch[4]

        # compute output (torch.autocast replaces the deprecated torch.cuda.amp.autocast)
        with torch.autocast(device_type="cuda"):
            output = self.model(samples)
            loss = self.loss_func(output, target)

        n = output.size(0)
        # One device->host transfer per tensor instead of one per element.
        return pd.DataFrame({
            'ID': [ids[i] for i in range(n)],
            'Target': [int(v) for v in target[:n].cpu().tolist()],
            'Chunk': [int(v) for v in chunk_nb[:n].cpu().tolist()],
            'Split': [int(v) for v in split_nb[:n].cpu().tolist()],
            'Loss': loss.item(),
            'Output': output.detach().cpu().tolist(),
        })

def merge_final_result(dataFrame, args, save_path, path):

    IDs = []
    Targets = []
    Outputs = []
    loss = dataFrame['Loss'].mean()
    softmax = lambda x: np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum(axis=0)
    dataFrame = dataFrame.sort_values(by='ID', ascending=False)
    step = args.test_num_segment * args.test_num_crop
    for i in range(0, len(dataFrame), step):
        subset = dataFrame.iloc[i:i+step].reset_index(drop=True)
        output = np.array(subset['Output'].tolist())
        output = np.mean(output, axis = 0)
        Outputs.append(output)
        IDs.append(subset['ID'][0])
        Target = subset['Target'].unique()
        # print(Target)
        if len(Target) != 1:
            print("모든 'Target' 값이 같지 않습니다.")
            continue
        else:
            Targets.append(Target[0])

    Targets = torch.tensor(Targets)
    Outputs = torch.tensor(Outputs)
    Outputs = F.softmax(Outputs, dim=1)

    acc1 = accuracy(Outputs, Targets, task='MULTICLASS', num_classes = args.nb_classes)
    acc5 = accuracy(Outputs, Targets, task='MULTICLASS',  num_classes = args.nb_classes, top_k=5)

    path = path.split('.')[0]

    print(f'Test_acc1:          {acc1.item()}')
    print(f'Test_acc5:          {acc5.item()}')
    if os.path.exists(f'{save_path}/test.csv'):
        result_csv = pd.read_csv(f'{save_path}/test.csv')
        result_csv[f'Test_acc1_{path}'] = [acc1.item()]
        result_csv[f'Test_acc5_{path}'] = [acc5.item()]
        result_csv.to_csv(f'{save_path}/test.csv', index=False)    
    else:
        result = pd.DataFrame({
            f'Test_acc1_{path}': [acc1.item()],
            f'Test_acc5_{path}': [acc5.item()]})
        result.to_csv(f'{save_path}/test.csv', index=False)    